import math
import sys
from copy import copy
from pathlib import Path

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.cuda import amp

from tfe.detector.yolov7.api.utils.datasets import letterbox
from tfe.detector.yolov7.api.utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
from tfe.detector.yolov7.api.utils.plots import color_list, plot_one_box
from tfe.detector.yolov7.api.utils.torch_utils import time_synchronized

##### basic ####

def autopad(k, p=None):  # kernel, padding
	# Pad to 'same'
	if p is None:
		p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
	return p


class MP(nn.Module):
	def __init__(self, k=2):
		super(MP, self).__init__()
		self.m = nn.MaxPool2d(kernel_size=k, stride=k)

	def forward(self, x):
		return self.m(x)


class SP(nn.Module):
	def __init__(self, k=3, s=1):
		super(SP, self).__init__()
		self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)

	def forward(self, x):
		return self.m(x)


class ReOrg(nn.Module):
	def __init__(self):
		super(ReOrg, self).__init__()

	def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
		return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)


class Concat(nn.Module):
	def __init__(self, dimension=1):
		super(Concat, self).__init__()
		self.d = dimension

	def forward(self, x):
		return torch.cat(x, self.d)


class Chuncat(nn.Module):
	def __init__(self, dimension=1):
		super(Chuncat, self).__init__()
		self.d = dimension

	def forward(self, x):
		x1 = []
		x2 = []
		for xi in x:
			xi1, xi2 = xi.chunk(2, self.d)
			x1.append(xi1)
			x2.append(xi2)
		return torch.cat(x1+x2, self.d)


class Shortcut(nn.Module):
	def __init__(self, dimension=0):
		super(Shortcut, self).__init__()
		self.d = dimension

	def forward(self, x):
		return x[0]+x[1]


class Foldcut(nn.Module):
	def __init__(self, dimension=0):
		super(Foldcut, self).__init__()
		self.d = dimension

	def forward(self, x):
		x1, x2 = x.chunk(2, self.d)
		return x1+x2


class Conv(nn.Module):
	# Standard convolution
	def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
		super(Conv, self).__init__()
		self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
		self.bn = nn.BatchNorm2d(c2)
		self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

	def forward(self, x):
		return self.act(self.bn(self.conv(x)))

	def fuseforward(self, x):
		return self.act(self.conv(x))


class RobustConv(nn.Module):
	# Robust convolution (use high kernel size 7-11 for: downsampling and other layers). Train for 300 - 450 epochs.
	def __init__(self, c1, c2, k=7, s=1, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
		super(RobustConv, self).__init__()
		self.conv_dw = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
		self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0, groups=1, bias=True)
		self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None

	def forward(self, x):
		x = x.to(memory_format=torch.channels_last)
		x = self.conv1x1(self.conv_dw(x))
		if self.gamma is not None:
			x = x.mul(self.gamma.reshape(1, -1, 1, 1))
		return x


class RobustConv2(nn.Module):
	# Robust convolution 2 (use [32, 5, 2] or [32, 7, 4] or [32, 11, 8] for one of the paths in CSP).
	def __init__(self, c1, c2, k=7, s=4, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
		super(RobustConv2, self).__init__()
		self.conv_strided = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
		self.conv_deconv = nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=s, stride=s,
											  padding=0, bias=True, dilation=1, groups=1
											  )
		self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None

	def forward(self, x):
		x = self.conv_deconv(self.conv_strided(x))
		if self.gamma is not None:
			x = x.mul(self.gamma.reshape(1, -1, 1, 1))
		return x


def DWConv(c1, c2, k=1, s=1, act=True):
	# Depthwise convolution
	return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class GhostConv(nn.Module):
	# Ghost Convolution https://github.com/huawei-noah/ghostnet
	def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
		super(GhostConv, self).__init__()
		c_ = c2 // 2  # hidden channels
		self.cv1 = Conv(c1, c_, k, s, None, g, act)
		self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

	def forward(self, x):
		y = self.cv1(x)
		return torch.cat([y, self.cv2(y)], 1)


class Stem(nn.Module):
	# Stem
	def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
		super(Stem, self).__init__()
		c_ = int(c2/2)  # hidden channels
		self.cv1 = Conv(c1, c_, 3, 2)
		self.cv2 = Conv(c_, c_, 1, 1)
		self.cv3 = Conv(c_, c_, 3, 2)
		self.pool = torch.nn.MaxPool2d(2, stride=2)
		self.cv4 = Conv(2 * c_, c2, 1, 1)

	def forward(self, x):
		x = self.cv1(x)
		return self.cv4(torch.cat((self.cv3(self.cv2(x)), self.pool(x)), dim=1))


class DownC(nn.Module):
	# Spatial pyramid pooling layer used in YOLOv3-SPP
	def __init__(self, c1, c2, n=1, k=2):
		super(DownC, self).__init__()
		c_ = int(c1)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c2//2, 3, k)
		self.cv3 = Conv(c1, c2//2, 1, 1)
		self.mp = nn.MaxPool2d(kernel_size=k, stride=k)

	def forward(self, x):
		return torch.cat((self.cv2(self.cv1(x)), self.cv3(self.mp(x))), dim=1)


class SPP(nn.Module):
	# Spatial pyramid pooling layer used in YOLOv3-SPP
	def __init__(self, c1, c2, k=(5, 9, 13)):
		super(SPP, self).__init__()
		c_ = c1 // 2  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
		self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

	def forward(self, x):
		x = self.cv1(x)
		return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class Bottleneck(nn.Module):
	# Darknet bottleneck
	def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super(Bottleneck, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c2, 3, 1, g=g)
		self.add = shortcut and c1 == c2

	def forward(self, x):
		return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class Res(nn.Module):
	# ResNet bottleneck
	def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super(Res, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c_, 3, 1, g=g)
		self.cv3 = Conv(c_, c2, 1, 1)
		self.add = shortcut and c1 == c2

	def forward(self, x):
		return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))


class ResX(Res):
	# ResNet bottleneck
	def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super().__init__(c1, c2, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels


class Ghost(nn.Module):
	# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
	def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
		super(Ghost, self).__init__()
		c_ = c2 // 2
		self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
								  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
								  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
		self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
									  Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

	def forward(self, x):
		return self.conv(x) + self.shortcut(x)

##### end of basic #####


##### cspnet #####

class SPPCSPC(nn.Module):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
		super(SPPCSPC, self).__init__()
		c_ = int(2 * c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(c_, c_, 3, 1)
		self.cv4 = Conv(c_, c_, 1, 1)
		self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
		self.cv5 = Conv(4 * c_, c_, 1, 1)
		self.cv6 = Conv(c_, c_, 3, 1)
		self.cv7 = Conv(2 * c_, c2, 1, 1)

	def forward(self, x):
		x1 = self.cv4(self.cv3(self.cv1(x)))
		y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
		y2 = self.cv2(x)
		return self.cv7(torch.cat((y1, y2), dim=1))

class GhostSPPCSPC(SPPCSPC):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
		super().__init__(c1, c2, n, shortcut, g, e, k)
		c_ = int(2 * c2 * e)  # hidden channels
		self.cv1 = GhostConv(c1, c_, 1, 1)
		self.cv2 = GhostConv(c1, c_, 1, 1)
		self.cv3 = GhostConv(c_, c_, 3, 1)
		self.cv4 = GhostConv(c_, c_, 1, 1)
		self.cv5 = GhostConv(4 * c_, c_, 1, 1)
		self.cv6 = GhostConv(c_, c_, 3, 1)
		self.cv7 = GhostConv(2 * c_, c2, 1, 1)


class GhostStem(Stem):
	# Stem
	def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
		super().__init__(c1, c2, k, s, p, g, act)
		c_ = int(c2/2)  # hidden channels
		self.cv1 = GhostConv(c1, c_, 3, 2)
		self.cv2 = GhostConv(c_, c_, 1, 1)
		self.cv3 = GhostConv(c_, c_, 3, 2)
		self.cv4 = GhostConv(2 * c_, c2, 1, 1)


class BottleneckCSPA(nn.Module):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(BottleneckCSPA, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.m(self.cv1(x))
		y2 = self.cv2(x)
		return self.cv3(torch.cat((y1, y2), dim=1))


class BottleneckCSPB(nn.Module):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(BottleneckCSPB, self).__init__()
		c_ = int(c2)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		x1 = self.cv1(x)
		y1 = self.m(x1)
		y2 = self.cv2(x1)
		return self.cv3(torch.cat((y1, y2), dim=1))


class BottleneckCSPC(nn.Module):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(BottleneckCSPC, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(c_, c_, 1, 1)
		self.cv4 = Conv(2 * c_, c2, 1, 1)
		self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.cv3(self.m(self.cv1(x)))
		y2 = self.cv2(x)
		return self.cv4(torch.cat((y1, y2), dim=1))


class ResCSPA(BottleneckCSPA):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResCSPB(BottleneckCSPB):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResCSPC(BottleneckCSPC):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResXCSPA(ResCSPA):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class ResXCSPB(ResCSPB):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class ResXCSPC(ResCSPC):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class GhostCSPA(BottleneckCSPA):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])


class GhostCSPB(BottleneckCSPB):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])


class GhostCSPC(BottleneckCSPC):
	# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])

##### end of cspnet #####


##### yolor #####

class ImplicitA(nn.Module):
	def __init__(self, channel, mean=0., std=.02):
		super(ImplicitA, self).__init__()
		self.channel = channel
		self.mean = mean
		self.std = std
		self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
		nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

	def forward(self, x):
		return self.implicit + x


class ImplicitM(nn.Module):
	def __init__(self, channel, mean=1., std=.02):
		super(ImplicitM, self).__init__()
		self.channel = channel
		self.mean = mean
		self.std = std
		self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
		nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

	def forward(self, x):
		return self.implicit * x

##### end of yolor #####


##### repvgg #####

class RepConv(nn.Module):
	# Represented convolution
	# https://arxiv.org/abs/2101.03697

	def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, deploy=False):
		super(RepConv, self).__init__()

		self.deploy = deploy
		self.groups = g
		self.in_channels = c1
		self.out_channels = c2

		assert k == 3
		assert autopad(k, p) == 1

		padding_11 = autopad(k, p) - k // 2

		self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

		if deploy:
			self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)

		else:
			self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)

			self.rbr_dense = nn.Sequential(
				nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
				nn.BatchNorm2d(num_features=c2),
			)

			self.rbr_1x1 = nn.Sequential(
				nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False),
				nn.BatchNorm2d(num_features=c2),
			)

	def forward(self, inputs):
		if hasattr(self, "rbr_reparam"):
			return self.act(self.rbr_reparam(inputs))

		if self.rbr_identity is None:
			id_out = 0
		else:
			id_out = self.rbr_identity(inputs)

		return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

	def get_equivalent_kernel_bias(self):
		kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
		kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
		kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
		return (
			kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
			bias3x3 + bias1x1 + biasid,
		)

	def _pad_1x1_to_3x3_tensor(self, kernel1x1):
		if kernel1x1 is None:
			return 0
		else:
			return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

	def _fuse_bn_tensor(self, branch):
		if branch is None:
			return 0, 0
		if isinstance(branch, nn.Sequential):
			kernel = branch[0].weight
			running_mean = branch[1].running_mean
			running_var = branch[1].running_var
			gamma = branch[1].weight
			beta = branch[1].bias
			eps = branch[1].eps
		else:
			assert isinstance(branch, nn.BatchNorm2d)
			if not hasattr(self, "id_tensor"):
				input_dim = self.in_channels // self.groups
				kernel_value = np.zeros(
					(self.in_channels, input_dim, 3, 3), dtype=np.float32
				)
				for i in range(self.in_channels):
					kernel_value[i, i % input_dim, 1, 1] = 1
				self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
			kernel = self.id_tensor
			running_mean = branch.running_mean
			running_var = branch.running_var
			gamma = branch.weight
			beta = branch.bias
			eps = branch.eps
		std = (running_var + eps).sqrt()
		t = (gamma / std).reshape(-1, 1, 1, 1)
		return kernel * t, beta - running_mean * gamma / std

	def repvgg_convert(self):
		kernel, bias = self.get_equivalent_kernel_bias()
		return (
			kernel.detach().cpu().numpy(),
			bias.detach().cpu().numpy(),
		)

	def fuse_conv_bn(self, conv, bn):

		std = (bn.running_var + bn.eps).sqrt()
		bias = bn.bias - bn.running_mean * bn.weight / std

		t = (bn.weight / std).reshape(-1, 1, 1, 1)
		weights = conv.weight * t

		bn = nn.Identity()
		conv = nn.Conv2d(in_channels = conv.in_channels,
						 out_channels = conv.out_channels,
						 kernel_size = conv.kernel_size,
						 stride=conv.stride,
						 padding = conv.padding,
						 dilation = conv.dilation,
						 groups = conv.groups,
						 bias = True,
						 padding_mode = conv.padding_mode)

		conv.weight = torch.nn.Parameter(weights)
		conv.bias = torch.nn.Parameter(bias)
		return conv

	def fuse_repvgg_block(self):
		if self.deploy:
			return
		print(f"RepConv.fuse_repvgg_block")

		self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])

		self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
		rbr_1x1_bias = self.rbr_1x1.bias
		weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])

		# Fuse self.rbr_identity
		if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
			# print(f"fuse: rbr_identity == BatchNorm2d or SyncBatchNorm")
			identity_conv_1x1 = nn.Conv2d(
				in_channels=self.in_channels,
				out_channels=self.out_channels,
				kernel_size=1,
				stride=1,
				padding=0,
				groups=self.groups,
				bias=False)
			identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
			identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
			# print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")
			identity_conv_1x1.weight.data.fill_(0.0)
			identity_conv_1x1.weight.data.fill_diagonal_(1.0)
			identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
			# print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")

			identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
			bias_identity_expanded = identity_conv_1x1.bias
			weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
		else:
			# print(f"fuse: rbr_identity != BatchNorm2d, rbr_identity = {self.rbr_identity}")
			bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
			weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )


			#print(f"self.rbr_1x1.weight = {self.rbr_1x1.weight.shape}, ")
		#print(f"weight_1x1_expanded = {weight_1x1_expanded.shape}, ")
		#print(f"self.rbr_dense.weight = {self.rbr_dense.weight.shape}, ")

		self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
		self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)

		self.rbr_reparam = self.rbr_dense
		self.deploy = True

		if self.rbr_identity is not None:
			del self.rbr_identity
			self.rbr_identity = None

		if self.rbr_1x1 is not None:
			del self.rbr_1x1
			self.rbr_1x1 = None

		if self.rbr_dense is not None:
			del self.rbr_dense
			self.rbr_dense = None


class RepBottleneck(Bottleneck):
	# Standard bottleneck
	def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super().__init__(c1, c2, shortcut=True, g=1, e=0.5)
		c_ = int(c2 * e)  # hidden channels
		self.cv2 = RepConv(c_, c2, 3, 1, g=g)


class RepBottleneckCSPA(BottleneckCSPA):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepBottleneckCSPB(BottleneckCSPB):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepBottleneckCSPC(BottleneckCSPC):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepRes(Res):
	# Standard bottleneck
	def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super().__init__(c1, c2, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.cv2 = RepConv(c_, c_, 3, 1, g=g)


class RepResCSPA(ResCSPA):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResCSPB(ResCSPB):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResCSPC(ResCSPC):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResX(ResX):
	# Standard bottleneck
	def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
		super().__init__(c1, c2, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.cv2 = RepConv(c_, c_, 3, 1, g=g)


class RepResXCSPA(ResXCSPA):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResXCSPB(ResXCSPB):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2)  # hidden channels
		self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResXCSPC(ResXCSPC):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super().__init__(c1, c2, n, shortcut, g, e)
		c_ = int(c2 * e)  # hidden channels
		self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])

##### end of repvgg #####


##### transformer #####

class TransformerLayer(nn.Module):
	# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
	def __init__(self, c, num_heads):
		super().__init__()
		self.q = nn.Linear(c, c, bias=False)
		self.k = nn.Linear(c, c, bias=False)
		self.v = nn.Linear(c, c, bias=False)
		self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
		self.fc1 = nn.Linear(c, c, bias=False)
		self.fc2 = nn.Linear(c, c, bias=False)

	def forward(self, x):
		x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
		x = self.fc2(self.fc1(x)) + x
		return x


class TransformerBlock(nn.Module):
	# Vision Transformer https://arxiv.org/abs/2010.11929
	def __init__(self, c1, c2, num_heads, num_layers):
		super().__init__()
		self.conv = None
		if c1 != c2:
			self.conv = Conv(c1, c2)
		self.linear = nn.Linear(c2, c2)  # learnable position embedding
		self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
		self.c2 = c2

	def forward(self, x):
		if self.conv is not None:
			x = self.conv(x)
		b, _, w, h = x.shape
		p = x.flatten(2)
		p = p.unsqueeze(0)
		p = p.transpose(0, 3)
		p = p.squeeze(3)
		e = self.linear(p)
		x = p + e

		x = self.tr(x)
		x = x.unsqueeze(3)
		x = x.transpose(0, 3)
		x = x.reshape(b, self.c2, w, h)
		return x

##### end of transformer #####


##### yolov5 #####

class Focus(nn.Module):
	# Focus wh information into c-space
	def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
		super(Focus, self).__init__()
		self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
		# self.contract = Contract(gain=2)

	def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
		return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
		# return self.conv(self.contract(x))


class SPPF(nn.Module):
	# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
	def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
		super().__init__()
		c_ = c1 // 2  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_ * 4, c2, 1, 1)
		self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

	def forward(self, x):
		x = self.cv1(x)
		y1 = self.m(x)
		y2 = self.m(y1)
		return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))


class Contract(nn.Module):
	# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
	def __init__(self, gain=2):
		super().__init__()
		self.gain = gain

	def forward(self, x):
		N, C, H, W = x.size()  # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
		s = self.gain
		x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
		x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
		return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)


class Expand(nn.Module):
	# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
	def __init__(self, gain=2):
		super().__init__()
		self.gain = gain

	def forward(self, x):
		N, C, H, W = x.size()  # assert C / s ** 2 == 0, 'Indivisible gain'
		s = self.gain
		x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
		x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
		return x.view(N, C // s ** 2, H * s, W * s)  # x(1,16,160,160)


class NMS(nn.Module):
	# Non-Maximum Suppression (NMS) module
	conf = 0.25  # confidence threshold
	iou = 0.45  # IoU threshold
	classes = None  # (optional list) filter by class

	def __init__(self):
		super(NMS, self).__init__()

	def forward(self, x):
		return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)


class autoShape(nn.Module):
	# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
	conf = 0.25  # NMS confidence threshold
	iou = 0.45  # NMS IoU threshold
	classes = None  # (optional list) filter by class

	def __init__(self, model):
		super(autoShape, self).__init__()
		self.model = model.eval()

	def autoshape(self):
		print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
		return self

	@torch.no_grad()
	def forward(self, imgs, size=640, augment=False, profile=False):
		# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
		#   filename:   imgs = 'data/samples/zidane.jpg'
		#   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
		#   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
		#   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
		#   numpy:           = np.zeros((640,1280,3))  # HWC
		#   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
		#   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

		t = [time_synchronized()]
		p = next(self.model.parameters())  # for device and type
		if isinstance(imgs, torch.Tensor):  # torch
			with amp.autocast(enabled=p.device.type != 'cpu'):
				return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

		# Pre-process
		n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
		shape0, shape1, files = [], [], []  # image and inference shapes, filenames
		for i, im in enumerate(imgs):
			f = f'image{i}'  # filename
			if isinstance(im, str):  # filename or uri
				im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
			elif isinstance(im, Image.Image):  # PIL Image
				im, f = np.asarray(im), getattr(im, 'filename', f) or f
			files.append(Path(f).with_suffix('.jpg').name)
			if im.shape[0] < 5:  # image in CHW
				im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
			im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
			s = im.shape[:2]  # HWC
			shape0.append(s)  # image shape
			g = (size / max(s))  # gain
			shape1.append([y * g for y in s])
			imgs[i] = im  # update
		shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
		x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
		x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
		x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
		x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
		t.append(time_synchronized())

		with amp.autocast(enabled=p.device.type != 'cpu'):
			# Inference
			y = self.model(x, augment, profile)[0]  # forward
			t.append(time_synchronized())

			# Post-process
			y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
			for i in range(n):
				scale_coords(shape1, y[i][:, :4], shape0[i])

			t.append(time_synchronized())
			return Detections(imgs, y, files, t, self.names, x.shape)


class Detections:
	# detections class for YOLOv5 inference results
	def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
		super(Detections, self).__init__()
		d = pred[0].device  # device
		gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
		self.imgs = imgs  # list of images as numpy arrays
		self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
		self.names = names  # class names
		self.files = files  # image filenames
		self.xyxy = pred  # xyxy pixels
		self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
		self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
		self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
		self.n = len(self.pred)  # number of images (batch size)
		self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
		self.s = shape  # inference BCHW shape

	def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
		colors = color_list()
		for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
			str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
			if pred is not None:
				for c in pred[:, -1].unique():
					n = (pred[:, -1] == c).sum()  # detections per class
					str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
				if show or save or render:
					for *box, conf, cls in pred:  # xyxy, confidence, class
						label = f'{self.names[int(cls)]} {conf:.2f}'
						plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
			img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
			if pprint:
				print(str.rstrip(', '))
			if show:
				img.show(self.files[i])  # show
			if save:
				f = self.files[i]
				img.save(Path(save_dir) / f)  # save
				print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
			if render:
				self.imgs[i] = np.asarray(img)

	def print(self):
		self.display(pprint=True)  # print results
		print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)

	def show(self):
		self.display(show=True)  # show results

	def save(self, save_dir='runs/hub/exp'):
		save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
		Path(save_dir).mkdir(parents=True, exist_ok=True)
		self.display(save=True, save_dir=save_dir)  # save results

	def render(self):
		self.display(render=True)  # render results
		return self.imgs

	def pandas(self):
		# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
		new = copy(self)  # return copy
		ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
		cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
		for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
			a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
			setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
		return new

	def tolist(self):
		# return a list of Detections objects, i.e. 'for result in results.tolist():'
		x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
		for d in x:
			for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
				setattr(d, k, getattr(d, k)[0])  # pop out of list
		return x

	def __len__(self):
		return self.n


class Classify(nn.Module):
	# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
	def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
		super(Classify, self).__init__()
		self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
		self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
		self.flat = nn.Flatten()

	def forward(self, x):
		z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
		return self.flat(self.conv(z))  # flatten to x(b,c2)

##### end of yolov5 ######


##### orepa #####

def transI_fusebn(kernel, bn):
	gamma = bn.weight
	std = (bn.running_var + bn.eps).sqrt()
	return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std


class ConvBN(nn.Module):
	def __init__(self, in_channels, out_channels, kernel_size,
				 stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
		super().__init__()
		if nonlinear is None:
			self.nonlinear = nn.Identity()
		else:
			self.nonlinear = nonlinear
		if deploy:
			self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
								  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
		else:
			self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
								  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
			self.bn = nn.BatchNorm2d(num_features=out_channels)

	def forward(self, x):
		if hasattr(self, 'bn'):
			return self.nonlinear(self.bn(self.conv(x)))
		else:
			return self.nonlinear(self.conv(x))

	def switch_to_deploy(self):
		kernel, bias = transI_fusebn(self.conv.weight, self.bn)
		conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size,
						 stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True)
		conv.weight.data = kernel
		conv.bias.data = bias
		for para in self.parameters():
			para.detach_()
		self.__delattr__('conv')
		self.__delattr__('bn')
		self.conv = conv

class OREPA_3x3_RepConv(nn.Module):

	def __init__(self, in_channels, out_channels, kernel_size,
				 stride=1, padding=0, dilation=1, groups=1,
				 internal_channels_1x1_3x3=None,
				 deploy=False, nonlinear=None, single_init=False):
		super(OREPA_3x3_RepConv, self).__init__()
		self.deploy = deploy

		if nonlinear is None:
			self.nonlinear = nn.Identity()
		else:
			self.nonlinear = nonlinear

		self.kernel_size = kernel_size
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.groups = groups
		assert padding == kernel_size // 2

		self.stride = stride
		self.padding = padding
		self.dilation = dilation

		self.branch_counter = 0

		self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size))
		nn.init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
		self.branch_counter += 1


		if groups < out_channels:
			self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
			self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
			nn.init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
			nn.init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
			self.weight_rbr_avg_conv.data
			self.weight_rbr_pfir_conv.data
			self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size))
			self.branch_counter += 1

		else:
			raise NotImplementedError
		self.branch_counter += 1

		if internal_channels_1x1_3x3 is None:
			internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels   # For mobilenet, it is better to have 2X internal channels

		if internal_channels_1x1_3x3 == in_channels:
			self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1))
			id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1))
			for i in range(in_channels):
				id_value[i, i % int(in_channels/self.groups), 0, 0] = 1
			id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
			self.register_buffer('id_tensor', id_tensor)

		else:
			self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1))
			nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
		self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size))
		nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
		self.branch_counter += 1

		expand_ratio = 8
		self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size))
		self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1))
		nn.init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
		nn.init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
		self.branch_counter += 1

		if out_channels == in_channels and stride == 1:
			self.branch_counter += 1

		self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
		self.bn = nn.BatchNorm2d(out_channels)

		self.fre_init()

		nn.init.constant_(self.vector[0, :], 0.25)    #origin
		nn.init.constant_(self.vector[1, :], 0.25)      #avg
		nn.init.constant_(self.vector[2, :], 0.0)      #prior
		nn.init.constant_(self.vector[3, :], 0.5)    #1x1_kxk
		nn.init.constant_(self.vector[4, :], 0.5)     #dws_conv


	def fre_init(self):
		prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
		half_fg = self.out_channels/2
		for i in range(self.out_channels):
			for h in range(3):
				for w in range(3):
					if i < half_fg:
						prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
					else:
						prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3)

		self.register_buffer('weight_rbr_prior', prior_tensor)

	def weight_gen(self):

		weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])

		weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])

		weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])

		weight_rbr_1x1_kxk_conv1 = None
		if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
			weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
		elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
			weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
		else:
			raise NotImplementedError
		weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2

		if self.groups > 1:
			g = self.groups
			t, ig = weight_rbr_1x1_kxk_conv1.size()
			o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
			weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig)
			weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w)
			weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
		else:
			weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)

		weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])

		weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
		weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])

		weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv

		return weight

	def dwsc2full(self, weight_dw, weight_pw, groups):

		t, ig, h, w = weight_dw.size()
		o, _, _, _ = weight_pw.size()
		tg = int(t/groups)
		i = int(ig*groups)
		weight_dw = weight_dw.view(groups, tg, ig, h, w)
		weight_pw = weight_pw.squeeze().view(o, groups, tg)

		weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
		return weight_dsc.view(o, i, h, w)

	def forward(self, inputs):
		weight = self.weight_gen()
		out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)

		return self.nonlinear(self.bn(out))

class RepConv_OREPA(nn.Module):

	def __init__(self, c1, c2, k=3, s=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.SiLU()):
		super(RepConv_OREPA, self).__init__()
		self.deploy = deploy
		self.groups = groups
		self.in_channels = c1
		self.out_channels = c2

		self.padding = padding
		self.dilation = dilation
		self.groups = groups

		assert k == 3
		assert padding == 1

		padding_11 = padding - k // 2

		if nonlinear is None:
			self.nonlinearity = nn.Identity()
		else:
			self.nonlinearity = nonlinear

		if use_se:
			self.se = SEBlock(self.out_channels, internal_neurons=self.out_channels // 16)
		else:
			self.se = nn.Identity()

		if deploy:
			self.rbr_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s,
										 padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

		else:
			self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.out_channels == self.in_channels and s == 1 else None
			self.rbr_dense = OREPA_3x3_RepConv(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s, padding=padding, groups=groups, dilation=1)
			self.rbr_1x1 = ConvBN(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=s, padding=padding_11, groups=groups, dilation=1)
			print('RepVGG Block, identity = ', self.rbr_identity)


	def forward(self, inputs):
		if hasattr(self, 'rbr_reparam'):
			return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

		if self.rbr_identity is None:
			id_out = 0
		else:
			id_out = self.rbr_identity(inputs)

		out1 = self.rbr_dense(inputs)
		out2 = self.rbr_1x1(inputs)
		out3 = id_out
		out = out1 + out2 + out3

		return self.nonlinearity(self.se(out))


	#   Optional. This improves the accuracy and facilitates quantization.
	#   1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
	#   2.  Use like this.
	#       loss = criterion(....)
	#       for every RepVGGBlock blk:
	#           loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
	#       optimizer.zero_grad()
	#       loss.backward()

	# Not used for OREPA
	def get_custom_L2(self):
		K3 = self.rbr_dense.weight_gen()
		K1 = self.rbr_1x1.conv.weight
		t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
		t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

		l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()      # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
		eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1                           # The equivalent resultant central point of 3x3 kernel.
		l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()        # Normalize for an L2 coefficient comparable to regular L2.
		return l2_loss_eq_kernel + l2_loss_circle

	def get_equivalent_kernel_bias(self):
		kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
		kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
		kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
		return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

	def _pad_1x1_to_3x3_tensor(self, kernel1x1):
		if kernel1x1 is None:
			return 0
		else:
			return torch.nn.functional.pad(kernel1x1, [1,1,1,1])

	def _fuse_bn_tensor(self, branch):
		if branch is None:
			return 0, 0
		if not isinstance(branch, nn.BatchNorm2d):
			if isinstance(branch, OREPA_3x3_RepConv):
				kernel = branch.weight_gen()
			elif isinstance(branch, ConvBN):
				kernel = branch.conv.weight
			else:
				raise NotImplementedError
			running_mean = branch.bn.running_mean
			running_var = branch.bn.running_var
			gamma = branch.bn.weight
			beta = branch.bn.bias
			eps = branch.bn.eps
		else:
			if not hasattr(self, 'id_tensor'):
				input_dim = self.in_channels // self.groups
				kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
				for i in range(self.in_channels):
					kernel_value[i, i % input_dim, 1, 1] = 1
				self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
			kernel = self.id_tensor
			running_mean = branch.running_mean
			running_var = branch.running_var
			gamma = branch.weight
			beta = branch.bias
			eps = branch.eps
		std = (running_var + eps).sqrt()
		t = (gamma / std).reshape(-1, 1, 1, 1)
		return kernel * t, beta - running_mean * gamma / std

	def switch_to_deploy(self):
		if hasattr(self, 'rbr_reparam'):
			return
		print(f"RepConv_OREPA.switch_to_deploy")
		kernel, bias = self.get_equivalent_kernel_bias()
		self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
									 kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
									 padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
		self.rbr_reparam.weight.data = kernel
		self.rbr_reparam.bias.data = bias
		for para in self.parameters():
			para.detach_()
		self.__delattr__('rbr_dense')
		self.__delattr__('rbr_1x1')
		if hasattr(self, 'rbr_identity'):
			self.__delattr__('rbr_identity')

		##### end of orepa #####


##### swin transformer #####    

class WindowAttention(nn.Module):

	def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

		super().__init__()
		self.dim = dim
		self.window_size = window_size  # Wh, Ww
		self.num_heads = num_heads
		head_dim = dim // num_heads
		self.scale = qk_scale or head_dim ** -0.5

		# define a parameter table of relative position bias
		self.relative_position_bias_table = nn.Parameter(
			torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

		# get pair-wise relative position index for each token inside the window
		coords_h = torch.arange(self.window_size[0])
		coords_w = torch.arange(self.window_size[1])
		coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
		coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
		relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
		relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
		relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
		relative_coords[:, :, 1] += self.window_size[1] - 1
		relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
		relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
		self.register_buffer("relative_position_index", relative_position_index)

		self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
		self.attn_drop = nn.Dropout(attn_drop)
		self.proj = nn.Linear(dim, dim)
		self.proj_drop = nn.Dropout(proj_drop)

		nn.init.normal_(self.relative_position_bias_table, std=.02)
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, x, mask=None):

		B_, N, C = x.shape
		qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
		q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

		q = q * self.scale
		attn = (q @ k.transpose(-2, -1))

		relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
			self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
		relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
		attn = attn + relative_position_bias.unsqueeze(0)

		if mask is not None:
			nW = mask.shape[0]
			attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
			attn = attn.view(-1, self.num_heads, N, N)
			attn = self.softmax(attn)
		else:
			attn = self.softmax(attn)

		attn = self.attn_drop(attn)

		# print(attn.dtype, v.dtype)
		try:
			x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
		except:
			#print(attn.dtype, v.dtype)
			x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
		x = self.proj(x)
		x = self.proj_drop(x)
		return x

class Mlp(nn.Module):

	def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
		super().__init__()
		out_features = out_features or in_features
		hidden_features = hidden_features or in_features
		self.fc1 = nn.Linear(in_features, hidden_features)
		self.act = act_layer()
		self.fc2 = nn.Linear(hidden_features, out_features)
		self.drop = nn.Dropout(drop)

	def forward(self, x):
		x = self.fc1(x)
		x = self.act(x)
		x = self.drop(x)
		x = self.fc2(x)
		x = self.drop(x)
		return x

def window_partition(x, window_size):

	B, H, W, C = x.shape
	assert H % window_size == 0, 'feature map h and w can not divide by window size'
	x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
	windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
	return windows

def window_reverse(windows, window_size, H, W):

	B = int(windows.shape[0] / (H * W / window_size / window_size))
	x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
	x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
	return x


class SwinTransformerLayer(nn.Module):

	def __init__(self, dim, num_heads, window_size=8, shift_size=0,
				 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
				 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
		super().__init__()
		self.dim = dim
		self.num_heads = num_heads
		self.window_size = window_size
		self.shift_size = shift_size
		self.mlp_ratio = mlp_ratio
		# if min(self.input_resolution) <= self.window_size:
		#     # if window size is larger than input resolution, we don't partition windows
		#     self.shift_size = 0
		#     self.window_size = min(self.input_resolution)
		assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

		self.norm1 = norm_layer(dim)
		self.attn = WindowAttention(
			dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
			qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

		self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
		self.norm2 = norm_layer(dim)
		mlp_hidden_dim = int(dim * mlp_ratio)
		self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

	def create_mask(self, H, W):
		# calculate attention mask for SW-MSA
		img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
		h_slices = (slice(0, -self.window_size),
					slice(-self.window_size, -self.shift_size),
					slice(-self.shift_size, None))
		w_slices = (slice(0, -self.window_size),
					slice(-self.window_size, -self.shift_size),
					slice(-self.shift_size, None))
		cnt = 0
		for h in h_slices:
			for w in w_slices:
				img_mask[:, h, w, :] = cnt
				cnt += 1

		mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
		mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
		attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
		attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

		return attn_mask

	def forward(self, x):
		# reshape x[b c h w] to x[b l c]
		_, _, H_, W_ = x.shape

		Padding = False
		if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
			Padding = True
			# print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
			pad_r = (self.window_size - W_ % self.window_size) % self.window_size
			pad_b = (self.window_size - H_ % self.window_size) % self.window_size
			x = F.pad(x, (0, pad_r, 0, pad_b))

		# print('2', x.shape)
		B, C, H, W = x.shape
		L = H * W
		x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

		# create mask from init to forward
		if self.shift_size > 0:
			attn_mask = self.create_mask(H, W).to(x.device)
		else:
			attn_mask = None

		shortcut = x
		x = self.norm1(x)
		x = x.view(B, H, W, C)

		# cyclic shift
		if self.shift_size > 0:
			shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
		else:
			shifted_x = x

		# partition windows
		x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
		x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

		# W-MSA/SW-MSA
		attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

		# merge windows
		attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
		shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

		# reverse cyclic shift
		if self.shift_size > 0:
			x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
		else:
			x = shifted_x
		x = x.view(B, H * W, C)

		# FFN
		x = shortcut + self.drop_path(x)
		x = x + self.drop_path(self.mlp(self.norm2(x)))

		x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

		if Padding:
			x = x[:, :, :H_, :W_]  # reverse padding

		return x


class SwinTransformerBlock(nn.Module):
	def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
		super().__init__()
		self.conv = None
		if c1 != c2:
			self.conv = Conv(c1, c2)

		# remove input_resolution
		self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
														   shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])

	def forward(self, x):
		if self.conv is not None:
			x = self.conv(x)
		x = self.blocks(x)
		return x


class STCSPA(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(STCSPA, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformerBlock(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.m(self.cv1(x))
		y2 = self.cv2(x)
		return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPB(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(STCSPB, self).__init__()
		c_ = int(c2)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformerBlock(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		x1 = self.cv1(x)
		y1 = self.m(x1)
		y2 = self.cv2(x1)
		return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPC(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(STCSPC, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(c_, c_, 1, 1)
		self.cv4 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformerBlock(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.cv3(self.m(self.cv1(x)))
		y2 = self.cv2(x)
		return self.cv4(torch.cat((y1, y2), dim=1))

##### end of swin transformer #####   


##### swin transformer v2 ##### 

class WindowAttention_v2(nn.Module):

	def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
				 pretrained_window_size=[0, 0]):

		super().__init__()
		self.dim = dim
		self.window_size = window_size  # Wh, Ww
		self.pretrained_window_size = pretrained_window_size
		self.num_heads = num_heads

		self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

		# mlp to generate continuous relative position bias
		self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
									 nn.ReLU(inplace=True),
									 nn.Linear(512, num_heads, bias=False))

		# get relative_coords_table
		relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
		relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
		relative_coords_table = torch.stack(
			torch.meshgrid([relative_coords_h,
							relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
		if pretrained_window_size[0] > 0:
			relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
			relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
		else:
			relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
			relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
		relative_coords_table *= 8  # normalize to -8, 8
		relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
			torch.abs(relative_coords_table) + 1.0) / np.log2(8)

		self.register_buffer("relative_coords_table", relative_coords_table)

		# get pair-wise relative position index for each token inside the window
		coords_h = torch.arange(self.window_size[0])
		coords_w = torch.arange(self.window_size[1])
		coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
		coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
		relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
		relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
		relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
		relative_coords[:, :, 1] += self.window_size[1] - 1
		relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
		relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
		self.register_buffer("relative_position_index", relative_position_index)

		self.qkv = nn.Linear(dim, dim * 3, bias=False)
		if qkv_bias:
			self.q_bias = nn.Parameter(torch.zeros(dim))
			self.v_bias = nn.Parameter(torch.zeros(dim))
		else:
			self.q_bias = None
			self.v_bias = None
		self.attn_drop = nn.Dropout(attn_drop)
		self.proj = nn.Linear(dim, dim)
		self.proj_drop = nn.Dropout(proj_drop)
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, x, mask=None):

		B_, N, C = x.shape
		qkv_bias = None
		if self.q_bias is not None:
			qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
		qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
		qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
		q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

		# cosine attention
		attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
		logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
		attn = attn * logit_scale

		relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
		relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
			self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
		relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
		relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
		attn = attn + relative_position_bias.unsqueeze(0)

		if mask is not None:
			nW = mask.shape[0]
			attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
			attn = attn.view(-1, self.num_heads, N, N)
			attn = self.softmax(attn)
		else:
			attn = self.softmax(attn)

		attn = self.attn_drop(attn)

		try:
			x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
		except:
			x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

		x = self.proj(x)
		x = self.proj_drop(x)
		return x

	def extra_repr(self) -> str:
		return f'dim={self.dim}, window_size={self.window_size}, ' \
			   f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

	def flops(self, N):
		# calculate flops for 1 window with token length of N
		flops = 0
		# qkv = self.qkv(x)
		flops += N * self.dim * 3 * self.dim
		# attn = (q @ k.transpose(-2, -1))
		flops += self.num_heads * N * (self.dim // self.num_heads) * N
		#  x = (attn @ v)
		flops += self.num_heads * N * N * (self.dim // self.num_heads)
		# x = self.proj(x)
		flops += N * self.dim * self.dim
		return flops

class Mlp_v2(nn.Module):
	def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
		super().__init__()
		out_features = out_features or in_features
		hidden_features = hidden_features or in_features
		self.fc1 = nn.Linear(in_features, hidden_features)
		self.act = act_layer()
		self.fc2 = nn.Linear(hidden_features, out_features)
		self.drop = nn.Dropout(drop)

	def forward(self, x):
		x = self.fc1(x)
		x = self.act(x)
		x = self.drop(x)
		x = self.fc2(x)
		x = self.drop(x)
		return x


def window_partition_v2(x, window_size):

	B, H, W, C = x.shape
	x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
	windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
	return windows


def window_reverse_v2(windows, window_size, H, W):

	B = int(windows.shape[0] / (H * W / window_size / window_size))
	x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
	x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
	return x


class SwinTransformerLayer_v2(nn.Module):

	def __init__(self, dim, num_heads, window_size=7, shift_size=0,
				 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
				 act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
		super().__init__()
		self.dim = dim
		#self.input_resolution = input_resolution
		self.num_heads = num_heads
		self.window_size = window_size
		self.shift_size = shift_size
		self.mlp_ratio = mlp_ratio
		#if min(self.input_resolution) <= self.window_size:
		#    # if window size is larger than input resolution, we don't partition windows
		#    self.shift_size = 0
		#    self.window_size = min(self.input_resolution)
		assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

		self.norm1 = norm_layer(dim)
		self.attn = WindowAttention_v2(
			dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
			qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
			pretrained_window_size=(pretrained_window_size, pretrained_window_size))

		self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
		self.norm2 = norm_layer(dim)
		mlp_hidden_dim = int(dim * mlp_ratio)
		self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

	def create_mask(self, H, W):
		# calculate attention mask for SW-MSA
		img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
		h_slices = (slice(0, -self.window_size),
					slice(-self.window_size, -self.shift_size),
					slice(-self.shift_size, None))
		w_slices = (slice(0, -self.window_size),
					slice(-self.window_size, -self.shift_size),
					slice(-self.shift_size, None))
		cnt = 0
		for h in h_slices:
			for w in w_slices:
				img_mask[:, h, w, :] = cnt
				cnt += 1

		mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
		mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
		attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
		attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

		return attn_mask

	def forward(self, x):
		# reshape x[b c h w] to x[b l c]
		_, _, H_, W_ = x.shape

		Padding = False
		if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
			Padding = True
			# print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
			pad_r = (self.window_size - W_ % self.window_size) % self.window_size
			pad_b = (self.window_size - H_ % self.window_size) % self.window_size
			x = F.pad(x, (0, pad_r, 0, pad_b))

		# print('2', x.shape)
		B, C, H, W = x.shape
		L = H * W
		x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

		# create mask from init to forward
		if self.shift_size > 0:
			attn_mask = self.create_mask(H, W).to(x.device)
		else:
			attn_mask = None

		shortcut = x
		x = x.view(B, H, W, C)

		# cyclic shift
		if self.shift_size > 0:
			shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
		else:
			shifted_x = x

		# partition windows
		x_windows = window_partition_v2(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
		x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

		# W-MSA/SW-MSA
		attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

		# merge windows
		attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
		shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W)  # B H' W' C

		# reverse cyclic shift
		if self.shift_size > 0:
			x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
		else:
			x = shifted_x
		x = x.view(B, H * W, C)
		x = shortcut + self.drop_path(self.norm1(x))

		# FFN
		x = x + self.drop_path(self.norm2(self.mlp(x)))
		x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

		if Padding:
			x = x[:, :, :H_, :W_]  # reverse padding

		return x

	def extra_repr(self) -> str:
		return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
			   f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

	def flops(self):
		flops = 0
		H, W = self.input_resolution
		# norm1
		flops += self.dim * H * W
		# W-MSA/SW-MSA
		nW = H * W / self.window_size / self.window_size
		flops += nW * self.attn.flops(self.window_size * self.window_size)
		# mlp
		flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
		# norm2
		flops += self.dim * H * W
		return flops


class SwinTransformer2Block(nn.Module):
	def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
		super().__init__()
		self.conv = None
		if c1 != c2:
			self.conv = Conv(c1, c2)

		# remove input_resolution
		self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
															  shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])

	def forward(self, x):
		if self.conv is not None:
			x = self.conv(x)
		x = self.blocks(x)
		return x


class ST2CSPA(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(ST2CSPA, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformer2Block(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.m(self.cv1(x))
		y2 = self.cv2(x)
		return self.cv3(torch.cat((y1, y2), dim=1))


class ST2CSPB(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(ST2CSPB, self).__init__()
		c_ = int(c2)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c_, c_, 1, 1)
		self.cv3 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformer2Block(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		x1 = self.cv1(x)
		y1 = self.m(x1)
		y2 = self.cv2(x1)
		return self.cv3(torch.cat((y1, y2), dim=1))


class ST2CSPC(nn.Module):
	# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
	def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
		super(ST2CSPC, self).__init__()
		c_ = int(c2 * e)  # hidden channels
		self.cv1 = Conv(c1, c_, 1, 1)
		self.cv2 = Conv(c1, c_, 1, 1)
		self.cv3 = Conv(c_, c_, 1, 1)
		self.cv4 = Conv(2 * c_, c2, 1, 1)
		num_heads = c_ // 32
		self.m = SwinTransformer2Block(c_, c_, num_heads, n)
		#self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

	def forward(self, x):
		y1 = self.cv3(self.m(self.cv1(x)))
		y2 = self.cv2(x)
		return self.cv4(torch.cat((y1, y2), dim=1))

##### end of swin transformer v2 #####   
